{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "355UKMUQJxFd"
   },
   "source": [
    "# Scalable Diffusion Models with Transformers (DiT)\n",
    "\n",
    "This notebook samples from pre-trained DiT models. DiTs are class-conditional latent diffusion models trained on ImageNet that use transformers in place of U-Nets as the DDPM backbone. DiT-XL/2 outperforms all prior diffusion models on the class-conditional ImageNet 256x256 and 512x512 benchmarks.\n",
    "\n",
    "[Project Page](https://www.wpeebles.com/DiT) | [HuggingFace Space](https://huggingface.co/spaces/wpeebles/DiT) | [Paper](http://arxiv.org/abs/2212.09748) | [GitHub](https://github.com/facebookresearch/DiT)"
   ]
  },
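  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The \"/2\" in DiT-XL/2 is the patch size. The transformer does not operate on pixels or on the full latent grid. Instead, it operates on a sequence of tokens, one per 2x2 patch of the VAE latent. The optional cell below is a minimal sketch of that patchify step and its inverse, assuming a `(N, C, H, W)` latent. `patchify` and `unpatchify` are hypothetical helpers written for illustration. The repository embeds patches with timm's `PatchEmbed`, a strided convolution that is equivalent to this reshape followed by a linear layer. Its exact token layout is not guaranteed to match this sketch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def patchify(x: torch.Tensor, p: int) -> torch.Tensor:\n",
    "    # (N, C, H, W) latent -> (N, T, p*p*C) tokens, with T = (H/p) * (W/p)\n",
    "    n, c, height, width = x.shape\n",
    "    assert height % p == 0 and width % p == 0  # spatial size must be divisible by p\n",
    "    h, w = height // p, width // p\n",
    "    x = x.reshape(n, c, h, p, w, p)        # split each spatial axis into (grid, within-patch)\n",
    "    x = x.permute(0, 2, 4, 3, 5, 1)        # -> (N, h, w, p, p, C)\n",
    "    return x.reshape(n, h * w, p * p * c)  # flatten each patch into one token\n",
    "\n",
    "def unpatchify(tokens: torch.Tensor, p: int, c: int, h: int, w: int) -> torch.Tensor:\n",
    "    # Inverse of patchify: (N, h*w, p*p*C) -> (N, C, h*p, w*p)\n",
    "    n = tokens.shape[0]\n",
    "    x = tokens.reshape(n, h, w, p, p, c)\n",
    "    x = x.permute(0, 5, 1, 3, 2, 4)        # -> (N, C, h, p, w, p)\n",
    "    return x.reshape(n, c, h * p, w * p)\n",
    "\n",
    "# A 256x256 image becomes a 4x32x32 latent; with p=2 that is 16*16 = 256 tokens of size 2*2*4 = 16.\n",
    "latent = torch.randn(2, 4, 32, 32)\n",
    "tokens = patchify(latent, p=2)\n",
    "print(tokens.shape)  # torch.Size([2, 256, 16])\n",
    "assert torch.equal(unpatchify(tokens, p=2, c=4, h=16, w=16), latent)  # exact round trip"
   ]
  },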
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zJlgLkSaKn7u"
   },
   "source": [
    "# 1. Setup\n",
    "\n",
    "We recommend using a GPU (Runtime > Change runtime type > Hardware accelerator > GPU). Run the cell below to clone the DiT GitHub repo, install its dependencies, and set up PyTorch. You only have to run it once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
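    "!git clone https://github.com/facebookresearch/DiT.git\n",
    "import os, sys\n",
    "os.chdir('DiT')\n",
    "sys.path.insert(0, os.getcwd())  # make the repo's modules (diffusion, download, models) importable\n",
    "os.environ['PYTHONPATH'] = '/env/python:/content/DiT'  # only affects child processes, not this kernel\n",
    "!pip install diffusers timm --upgrade\n",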
    "# DiT imports:\n",
    "import torch\n",
    "from torchvision.utils import save_image\n",
    "from diffusion import create_diffusion\n",
    "from diffusers.models import AutoencoderKL\n",
    "from download import find_model\n",
    "from models import DiT_XL_2\n",
    "from PIL import Image\n",
    "from IPython.display import display\n",
    "torch.set_grad_enabled(False)\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "if device == \"cpu\":\n",
    "    print(\"GPU not found. Using CPU instead.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "AXpziRkoOvV9"
   },
   "source": [
    "# Download DiT-XL/2 Models\n",
    "\n",
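    "You can choose between a 512x512 model and a 256x256 model. You can also swap out the VAE that decodes latents into images: `sd-vae-ft-ema` and `sd-vae-ft-mse` are two fine-tuned versions of the Stable Diffusion VAE."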
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EWG-WNimO59K"
   },
   "outputs": [],
   "source": [
    "image_size = 256 #@param [256, 512]\n",
    "vae_model = \"stabilityai/sd-vae-ft-ema\" #@param [\"stabilityai/sd-vae-ft-mse\", \"stabilityai/sd-vae-ft-ema\"]\n",
    "latent_size = int(image_size) // 8\n",
    "# Load model:\n",
    "model = DiT_XL_2(input_size=latent_size).to(device)\n",
    "state_dict = find_model(f\"DiT-XL-2-{image_size}x{image_size}.pt\")\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()  # important! in train mode, DiT randomly drops class labels (used for CFG training)\n",
    "vae = AutoencoderKL.from_pretrained(vae_model).to(device)"
   ]
  },
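  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The VAE downsamples each spatial side by 8, which is why `latent_size = int(image_size) // 8`. DiT-XL/2 then cuts that latent into 2x2 patches, so the number of tokens the transformer processes grows quadratically with image resolution. Self-attention cost in turn grows with the square of the token count, so a 512x512 forward pass is considerably more expensive than a 256x256 one. The optional cell below tabulates the shapes for both checkpoints. `dit_shapes` is a hypothetical helper that assumes the 8x VAE factor and patch size 2 used in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dit_shapes(image_size: int, vae_downsample: int = 8, patch_size: int = 2):\n",
    "    # Latent side length after the VAE encoder (the latent has 4 channels).\n",
    "    side = image_size // vae_downsample\n",
    "    # One transformer token per patch_size x patch_size latent patch.\n",
    "    num_tokens = (side // patch_size) ** 2\n",
    "    return side, num_tokens\n",
    "\n",
    "# Loop names are chosen so they do not clobber the notebook's `latent_size`.\n",
    "for res in (256, 512):\n",
    "    side, num_tokens = dit_shapes(res)\n",
    "    print(f\"{res}x{res} image -> 4x{side}x{side} latent -> {num_tokens} tokens\")\n",
    "# 256x256 image -> 4x32x32 latent -> 256 tokens\n",
    "# 512x512 image -> 4x64x64 latent -> 1024 tokens"
   ]
  },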
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5JTNyzNZKb9E"
   },
   "source": [
    "# 2. Sample from Pre-trained DiT Models\n",
    "\n",
    "You can customize several sampling options. For the full list of ImageNet class indices, see [this list](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-Hw7B5h4Kk4p"
   },
   "outputs": [],
   "source": [
    "# Set user inputs:\n",
    "seed = 0 #@param {type:\"number\"}\n",
    "torch.manual_seed(seed)\n",
    "num_sampling_steps = 250 #@param {type:\"slider\", min:1, max:1000, step:1}\n",
    "cfg_scale = 4 #@param {type:\"slider\", min:1, max:10, step:0.1}\n",
    "class_labels = 207, 360, 387, 974, 88, 979, 417, 279 #@param {type:\"raw\"}\n",
    "samples_per_row = 4 #@param {type:\"number\"}\n",
    "\n",
    "# Create diffusion object:\n",
    "diffusion = create_diffusion(str(num_sampling_steps))\n",
    "\n",
    "# Create sampling noise:\n",
    "n = len(class_labels)\n",
    "z = torch.randn(n, 4, latent_size, latent_size, device=device)\n",
    "y = torch.tensor(class_labels, device=device)\n",
    "\n",
    "# Set up classifier-free guidance (label 1000 is the null class):\n",
    "z = torch.cat([z, z], 0)\n",
    "y_null = torch.tensor([1000] * n, device=device)\n",
    "y = torch.cat([y, y_null], 0)\n",
    "model_kwargs = dict(y=y, cfg_scale=cfg_scale)\n",
    "\n",
    "# Sample images:\n",
    "samples = diffusion.p_sample_loop(\n",
    "    model.forward_with_cfg, z.shape, z, clip_denoised=False,\n",
    "    model_kwargs=model_kwargs, progress=True, device=device\n",
    ")\n",
    "samples, _ = samples.chunk(2, dim=0)  # Remove null class samples\n",
    "samples = vae.decode(samples / 0.18215).sample  # undo the Stable Diffusion VAE latent scaling factor\n",
    "\n",
    "# Save and display images:\n",
    "save_image(samples, \"sample.png\", nrow=int(samples_per_row),\n",
    "           normalize=True, value_range=(-1, 1))\n",
    "grid = Image.open(\"sample.png\")\n",
    "display(grid)"
   ]
  }
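  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sampling cell doubles the batch, pairing each noise tensor once with its requested class and once with the null label 1000. `model.forward_with_cfg` then combines the two predictions with the standard classifier-free guidance rule `eps = eps_uncond + cfg_scale * (eps_cond - eps_uncond)`. With `cfg_scale = 1` this reduces to the plain conditional prediction, and larger values push samples further toward the requested class. The optional cell below is a simplified sketch on random tensors. `cfg_combine` is a hypothetical helper, and the repository's `forward_with_cfg` also handles details omitted here, such as which output channels are guided."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def cfg_combine(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, scale: float) -> torch.Tensor:\n",
    "    # Classifier-free guidance: start from the unconditional prediction and move\n",
    "    # toward (and, for scale > 1, past) the class-conditional one.\n",
    "    return eps_uncond + scale * (eps_cond - eps_uncond)\n",
    "\n",
    "# Mimic the doubled batch from the sampling cell: [requested classes | null class 1000].\n",
    "batch = 8\n",
    "fake_out = torch.randn(2 * batch, 4, 32, 32)     # stand-in for the model's noise prediction\n",
    "eps_cond, eps_uncond = fake_out.chunk(2, dim=0)  # same split as samples.chunk(2, dim=0)\n",
    "guided = cfg_combine(eps_cond, eps_uncond, scale=4.0)\n",
    "print(guided.shape)  # torch.Size([8, 4, 32, 32])\n",
    "assert torch.allclose(cfg_combine(eps_cond, eps_uncond, 1.0), eps_cond, atol=1e-6)  # scale 1 = conditional"
   ]
  }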
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
